using Statistics;
import Optim;

include("../expfam.jl");

# Bounded family in [0,1]
# Scaled Bernoulli source: `sample` multiplies each draw of `generator`
# (a `Bernoulli()` defined outside this file — presumably 0/1-valued, TODO confirm)
# by `B`.
struct BernoulliGen
    μ           # parameter forwarded to `sample(rng, generator, μ)`
    B           # multiplicative scale applied to every draw
    generator   # the wrapped `Bernoulli()` instance
    function BernoulliGen(μ, B)
        gen = Bernoulli()
        return new(μ, B, gen)
    end
end

# Draw one value: a draw of the wrapped generator at parameter μ, scaled by B.
function sample(rng, dist::BernoulliGen)
    raw = sample(rng, dist.generator, dist.μ)
    return dist.B * raw
end

# Empirical source: `sample` returns one entry of `histogram`, chosen uniformly.
struct HistogramGen
    μ          # stored alongside the data; not read by `sample`
    histogram  # collection of support points to draw from
    B          # presumably the upper bound of the support — TODO confirm
    function HistogramGen(μ, histogram, B)
        return new(μ, histogram, B)
    end
end

# Draw one entry of `histogram` uniformly at random using `rng`.
function sample(rng, dist::HistogramGen)
    n = length(dist.histogram)
    pick = first(rand(rng, 1:n, 1))
    return dist.histogram[pick]
end

# Relative-entropy term x * log(x / y), with the convention that it is 0 when x is 0.
function rel_entr(x, y)
    if iszero(x)
        return 0.0
    end
    return x * log(x / y)
end
"""
    Kinf_emp(_isBer, samples, μ, u, B, is_Kinfp)

Empirical `K_inf` term for data supported on `[0, B]`. `is_Kinfp = true` selects
the upper side (constraint "mean ≥ u"). `is_Kinfp = false` selects the lower side
(constraint "mean ≤ u").

# Arguments
- `_isBer`: if `true`, use the closed-form binary KL between `μ/B` and `u/B`
  (two-point support `{0, B}`); `samples` is then ignored.
- `samples`: observations, assumed to lie in `[0, B]` (used only when `!_isBer`).
- `μ`: mean of the data (presumably the empirical mean), in the same units as `u`.
- `u`: threshold; must satisfy `0 ≤ u ≤ B` (checked with `@assert`).
- `B`: upper end of the support.
- `is_Kinfp`: selects the upper (`true`) or lower (`false`) side.

# Returns
- `0.0` when `μ` already satisfies the constraint (`u ≤ μ` upper, `u ≥ μ` lower).
- `Inf` when `u` lies on the bound beyond `μ` (`u == B` upper, `u == 0` lower).
- `Inf` when the numerical optimisation fails.
- Otherwise, the non-negative divergence value.
"""
function Kinf_emp(_isBer, samples, μ, u, B, is_Kinfp)
    @assert u <= B && u >= 0 "Domain violation for u: $(u) ∈ [0, $(B)]"
    # K_inf vanishes when μ itself already satisfies the mean constraint.
    # Return a Float64 zero so every branch has the same return type.
    if is_Kinfp ? u <= μ : u >= μ
        return 0.0
    end
    return _isBer ? _Kinf_bernoulli(μ, u, B) : _Kinf_bounded(samples, u, B, is_Kinfp)
end

# Closed-form K_inf for a two-point law on {0, B}: the binary KL between the
# normalised means μ/B and u/B. This equals the unnormalised formula when B == 1,
# and avoids log of a negative number when B > 1 (since u may exceed 1).
function _Kinf_bernoulli(μ, u, B)
    p = μ / B
    q = u / B
    # Clamp tiny negative round-off; the KL itself is non-negative.
    return max(0.0, rel_entr(p, q) + rel_entr(1 - p, 1 - q))
end

# K_inf for general data on [0, B], via the dual objective mean(log(1 - λ Y)),
# λ ∈ [0, 1]. The caller has already excluded the zero case.
function _Kinf_bounded(samples, u, B, is_Kinfp)
    if is_Kinfp
        # Here u == B > μ. The only law on [0, B] with mean ≥ B is the point mass at B,
        # which cannot explain data with mean below B. Returning here also avoids
        # dividing by B - u == 0 below.
        u == B && return Inf
        # The dual objective at λ = 1 is mean(log((B - X) / (B - u))).
        # The condition below selects the case where that endpoint is the maximiser.
        if maximum(samples) < B && mean((B - u) ./ (B .- samples)) <= 1
            return mean(log.((B .- samples) / (B - u)))
        end
        Y = (samples .- u) / (B - u)
    else
        # Symmetric boundary: u == 0 < μ, where only the point mass at 0 qualifies.
        u == 0 && return Inf
        # The dual objective at λ = 1 is mean(log(X / u)).
        if minimum(samples) > 0 && mean(u ./ samples) <= 1
            return mean(log.(samples / u))
        end
        Y = (u .- samples) / u
    end
    return _Kinf_dual_max(Y)
end

# Maximise mean(log(1 - λ Y)) over λ ∈ [0, 1] with Optim's univariate solver.
# Returns Inf (after logging a warning) if the solver does not converge or throws.
function _Kinf_dual_max(Y)
    try
        res = Optim.optimize(λ -> -mean(log.(1 .- λ * Y)), 0.0, 1.0)
        if Optim.converged(res)
            # The objective is 0 at λ = 0, so a negative result is solver round-off.
            return max(0.0, -Optim.minimum(res))
        end
        @warn "Optimization failed"
        return Inf
    catch e
        # Never swallow a user interrupt; any other error degrades to Inf.
        e isa InterruptException && rethrow()
        @warn "K_inf optimization raised an exception" exception = e
        return Inf
    end
end

# Lower end: search [0, μ] with `binary_search` (defined outside this file;
# presumably a root finder on the interval — TODO confirm). The target is
# v - Kinf_emp(..., x, ..., false). Returns 0.0 directly when μ is zero.
function Kinf_dn(_isBer, samples, μ, v, B)
    iszero(μ) && return 0.0
    gap(x) = v - Kinf_emp(_isBer, samples, μ, x, B, false)
    return binary_search(gap, 0, μ)
end
# Upper end: search [μ, B] with `binary_search` (defined outside this file;
# presumably a root finder on the interval — TODO confirm). The target is
# Kinf_emp(..., x, ..., true) - v. Returns B unchanged when μ already equals B.
function Kinf_up(_isBer, samples, μ, v, B)
    μ == B && return B
    excess(x) = Kinf_emp(_isBer, samples, μ, x, B, true) - v
    return binary_search(excess, μ, B)
end
